"""Main entry point for doing all pruning-related stuff. Adapted from https://github.com/arunmallya/packnet/blob/master/src/main.py"""
from __future__ import division, print_function

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import sys
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pickle
import numpy as np
import time
import art.attacks.evasion
import mi_estimator
import warnings
# To prevent PIL warnings.
warnings.filterwarnings("ignore")
import pytorch_lightning
from torchmetrics import Accuracy

import data
from torch.autograd import Variable
from tqdm import tqdm
import torchnet as tnt
from art.estimators.classification import PyTorchClassifier
from torchsummary import summary
import utils


######################################################################################################################################################################
###
###     Main function
###
######################################################################################################################################################################


class Manager(object):
    """Handles training and pruning."""

    def __init__(self, args, model, trainloader, testloader, advtrainloader, advtestloader):
        """Keep references to the configuration, the model and the data loaders.

        Args:
            args: Run configuration object, stored as ``self.args``.
            model: Network to train and evaluate, stored as ``self.model``.
            trainloader: Loader used for non-adversarial training data.
            testloader: Loader used for non-adversarial test data.
            advtrainloader: Loader used for adversarial training data.
            advtestloader: Loader used for adversarial test data.
        """
        self.args = args
        self.model = model

        # One (clean, adversarial) loader pair per split.
        self.train_data_loader, self.adv_train_data_loader = trainloader, advtrainloader
        self.test_data_loader, self.adv_test_data_loader = testloader, advtestloader

        # Loss used by every training step.
        self.criterion = nn.CrossEntropyLoss()

        # Buffers filled in later by the linear-combination routines.
        self.baseline_acts, self.lincom_acts, self.loss = [], [], []



    def eval(self, biases=None, adversarial=False, data="Testing", lincom=False):
        """Evaluate the model on one data split and return its top-k errors.

        The model is put in eval mode for the pass and switched back to train
        mode before returning.

        Args:
            biases: Unused; accepted for backward compatibility.
            adversarial: ``True`` selects the adversarial loader of the split,
                ``False`` the clean one.
            data: ``"Testing"`` or ``"Training"``. Any combination other than
                (Testing, adversarial), (Training, adversarial) and
                (Testing, clean) falls back to the clean training loader.
            lincom: Assigned to ``self.model.lincom``. When ``True``,
                ``self.model.lincom1.reset_xs()`` is called before the pass.

        Returns:
            The value of the torchnet ``ClassErrorMeter``: top-1 error, plus
            top-5 error when the model emits more than five outputs.

        Raises:
            ValueError: If the selected loader yields no batches.
        """
        self.model.eval()
        self.model.lincom = lincom

        if data == "Testing" and adversarial == True:
            dataloader = self.adv_test_data_loader
        elif data == "Training" and adversarial == True:
            dataloader = self.adv_train_data_loader
        elif data == "Testing" and adversarial == False:
            dataloader = self.test_data_loader
        else:
            print("Using Training, Non-adversarial Data for Evaluation")
            dataloader = self.train_data_loader

        if lincom == True:
            self.model.lincom1.reset_xs()

        print('Performing eval...')
        error_meter = None
        topk = [1]

        # The adversarial and clean passes are identical apart from the loader,
        # so a single loop serves both. ``torch.no_grad()`` replaces the old
        # ``Variable(..., volatile=True)``, which no longer disables autograd.
        with torch.no_grad():
            for batch, label in tqdm(dataloader, desc='Eval'):
                output = self.model(batch.cuda())

                # The meter is created lazily because the number of outputs
                # (and hence whether top-5 applies) is only known after the
                # first forward pass.
                if error_meter is None:
                    if output.size(1) > 5:
                        topk.append(5)
                    error_meter = tnt.meter.ClassErrorMeter(topk=topk)
                error_meter.add(output.data, label)

        if error_meter is None:
            # Nothing was evaluated; restore train mode and fail loudly instead
            # of crashing on ``None.value()``.
            self.model.train()
            raise ValueError('Cannot evaluate: the selected data loader yielded no batches.')

        errors = error_meter.value()
        print('Error: ' + ', '.join('@%s=%.2f' %
                                    t for t in zip(topk, errors)))
        self.model.train()
        return errors
    



    def do_epoch(self, epoch_idx, optimizer, adversarial=False):
        """Run one optimisation pass over the training data.

        Args:
            epoch_idx: Epoch number; only used in the progress-bar label.
            optimizer: Optimizer whose ``step()`` is called after every
                backward pass.
            adversarial: ``True`` iterates the adversarial training loader,
                anything else the clean training loader.
        """
        # Both modes apply the same update rule; only the data source differs.
        if adversarial == True:
            loader = self.adv_train_data_loader
        else:
            loader = self.train_data_loader

        for batch, label in tqdm(loader, desc='Epoch: %d ' % (epoch_idx)):
            # Tensors are autograd-aware on their own, so no Variable wrapper
            # is needed.
            batch = batch.cuda()
            label = label.cuda()

            # Set grads to 0.
            self.model.zero_grad()

            # Do forward-backward.
            output = self.model(batch)
            self.criterion(output, label).backward()

            # Update params.
            optimizer.step()
        
    

    ### Outputs: Best top-1 accuracy observed over the epochs
    ### Stores: The best-scoring model weights, written to the checkpoint file on disk
    def train(self, epochs, optimizer, save=False, target_accuracy=0, best_accuracy=0, adversarial=False, eps=0, patience=3):
        """Train for ``epochs`` epochs, checkpointing the best model seen.

        After every epoch the model is evaluated on the test split. If the
        top-1 accuracy is at least ``best_accuracy`` the weights are saved to
        ``./saves/<dataset>/<network>/<attacktype>/checkpoint``. Otherwise a
        countdown that starts at ``patience`` is decremented; once it has
        reached zero, the next non-improving epoch reloads the checkpoint and
        restarts the countdown. The countdown is not reset by an improving
        epoch.

        Args:
            epochs: Number of epochs to run.
            optimizer: Optimizer passed through to ``do_epoch``.
            save: Unused; accepted for backward compatibility.
            target_accuracy: Unused; accepted for backward compatibility.
            best_accuracy: Accuracy (in percent) a model must reach or exceed
                to be checkpointed.
            adversarial: Train and evaluate on the adversarial loaders when
                ``True``.
            eps: Unused; accepted for backward compatibility.
            patience: Number of non-improving epochs tolerated before the next
                one triggers a checkpoint reload. Defaults to 3.

        Returns:
            The best top-1 accuracy (in percent) reached, or the incoming
            ``best_accuracy`` if it was never matched.
        """
        base_path = os.path.join("./saves", self.args.dataset, self.args.network, self.args.attacktype)
        os.makedirs(base_path, exist_ok=True)
        checkpoint_path = os.path.join(base_path, "checkpoint")

        self.model = self.model.cuda()

        patience_left = patience
        for epoch_idx in range(1, epochs + 1):
            print('Epoch: %d' % (epoch_idx))

            self.model.train()
            self.do_epoch(epoch_idx, optimizer, adversarial=adversarial)

            errors = self.eval(adversarial=adversarial, data="Testing", lincom=False)
            accuracy = 100 - errors[0]  # Top-1 accuracy.

            if accuracy >= best_accuracy:
                self.save_model(checkpoint_path)
                print('Best model so far, Accuracy: %0.2f%% -> %0.2f%%' %(best_accuracy, accuracy))
                best_accuracy = accuracy
            elif patience_left <= 0:
                # Roll back to the best weights. The checkpoint may be missing
                # when ``best_accuracy`` was passed in and never reached, so
                # keep the current weights instead of crashing.
                if os.path.exists(checkpoint_path):
                    self.load_model(torch.load(checkpoint_path))
                else:
                    print('No checkpoint found at %s; keeping current weights.' % checkpoint_path)
                patience_left = patience
            else:
                patience_left -= 1

        print('Finished finetuning...')
        print('Best error/accuracy: %0.2f%%, %0.2f%%' %(100 - best_accuracy, best_accuracy))
        print('-' * 16)
        return best_accuracy
    
    
    
    
    #####################################################################################################################################
    ###    Linear Combination Functions
    #####################################################################################################################################
    
    
    
    
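    ### Outputs: List of per-batch losses (all 0 for a baseline pass, i.e. lincom=False)
    ### Stores: 
    ###        f_b* acts as baseline_acts (when lincom=False)
    ###        f_b acts as lincom_acts (when lincom=True)
    ###        loss as frobenius norm of baseline_acts-lincom_acts, one entry per batch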
    def evalLincom(self, biases=None, store=False, lincom=True, adversarial=False, data="Testing"):
        """Record per-batch model outputs and their distance to the baseline.

        With ``lincom == False`` the outputs of every batch are stored in
        ``self.baseline_acts`` and a loss of 0 is recorded per batch. Otherwise
        the outputs are stored in ``self.lincom_acts`` and the loss of a batch
        is ``torch.linalg.matrix_norm`` of (baseline - lincom) for the batch at
        the same position, so a baseline pass over the same loader must have
        been run beforehand.

        The model is put in eval mode for the pass and switched back to train
        mode before returning.

        Args:
            biases: Unused; accepted for backward compatibility.
            store: Unused; accepted for backward compatibility.
            lincom: Assigned to ``self.model.lincom``. When ``True``,
                ``self.model.lincom1.reset_xs()`` is called before the pass.
            adversarial: ``True`` selects the adversarial loader of the split.
            data: ``"Testing"`` or ``"Training"``; unmatched combinations fall
                back to the clean training loader.

        Returns:
            A list with one loss entry per batch (empty if the loader is
            empty).

        Raises:
            IndexError: In lincom mode, if no baseline activations exist for a
                batch position.
        """
        self.model.eval()
        self.model.lincom = lincom
        self.output = []

        if data == "Testing" and adversarial == True:
            dataloader = self.adv_test_data_loader
        elif data == "Training" and adversarial == True:
            dataloader = self.adv_train_data_loader
        elif data == "Testing" and adversarial == False:
            dataloader = self.test_data_loader
        else:
            print("Using Training, Non-adversarial Data for Evaluation")
            dataloader = self.train_data_loader

        print('Performing eval...')
        if lincom == True:
            self.model.lincom1.reset_xs()

        # Initialise the accumulators up front so an empty loader returns an
        # empty loss list instead of hitting an unbound local.
        baseline_pass = (lincom == False)
        loss = []
        if baseline_pass:
            self.baseline_acts = []
        else:
            self.lincom_acts = []

        # ``torch.no_grad()`` replaces ``Variable(..., volatile=True)``, which
        # no longer disables autograd.
        with torch.no_grad():
            for batchnum, (batch, label) in enumerate(tqdm(dataloader, desc='Eval')):
                acts = self.model(batch.cuda()).detach().cpu().numpy()

                if baseline_pass:
                    self.baseline_acts.append(acts)
                    loss.append(0)
                    continue

                self.lincom_acts.append(acts)
                if batchnum >= len(self.baseline_acts):
                    raise IndexError(
                        'No baseline activations for batch %d; run evalLincom with '
                        'lincom=False on the same data first.' % batchnum)

                ptbaseline = torch.from_numpy(self.baseline_acts[batchnum])
                ptlincom = torch.from_numpy(acts)
                loss.append(torch.linalg.matrix_norm(ptbaseline - ptlincom))

        # Check batch by batch: the last batch may be smaller than the others,
        # and stacking differently-shaped arrays into one ndarray fails.
        has_nan = any(np.isnan(np.sum(acts)) for acts in self.baseline_acts)
        print("Nans in baseline_acts: ", has_nan)
        self.model.train()
        return loss
        
        
    
    
    

    def update_lambdas(self):
        """Refit ``self.model.lincom1.lambdas`` against the baseline activations.

        The coefficients are obtained as ``pinv(xs) @ baseline_acts`` (computed
        in float64), where ``xs`` are the per-batch inputs cached in
        ``self.model.lincom1.xs``.

        * ``self.args.avg == "True"``: one solution per batch, averaged over
          the batches.
        * otherwise: all batches are stacked along axis 0 and a single
          solution is computed.

        The result is stored as a NumPy array.

        Raises:
            ValueError: If ``self.baseline_acts`` is empty.
        """
        if not self.baseline_acts:
            raise ValueError('update_lambdas needs baseline activations; '
                             'run evalLincom with lincom=False first.')

        cached_xs = self.model.lincom1.xs

        if self.args.avg == "True":
            print("Using avg")
            batch_lambdas = []
            for i, baseline in enumerate(self.baseline_acts):
                xs_batch = cached_xs[i]
                # Put the targets wherever the cached inputs live instead of
                # assuming CUDA.
                baseline_acts_batch = torch.from_numpy(baseline).to(xs_batch.device)
                xs_batch_inv = torch.linalg.pinv(xs_batch)

                l_batch = torch.mm(xs_batch_inv.type(torch.float64),
                                   baseline_acts_batch.type(torch.float64))
                batch_lambdas.append(l_batch.detach().cpu().numpy())
            l_optimal = np.mean(np.asarray(batch_lambdas), axis=0)
        else:
            print("Taking full dataset inverse")
            # Gather every batch first and concatenate once, rather than
            # re-copying the growing array on each iteration.
            xs_parts = [cached_xs[i].detach().cpu().numpy()
                        for i in range(len(self.baseline_acts))]
            full_xs = torch.from_numpy(np.concatenate(xs_parts, axis=0))
            full_baseline_acts = torch.from_numpy(np.concatenate(self.baseline_acts, axis=0))

            pinv = torch.linalg.pinv(full_xs)
            print("pinv size: ", pinv.size())
            print("full_baseline_acts size: ", full_baseline_acts.size())
            l_optimal = torch.mm(pinv.type(torch.float64), full_baseline_acts.type(torch.float64))
            l_optimal = l_optimal.detach().cpu().numpy()

        self.model.lincom1.lambdas = l_optimal
    
    
    
    

    def train_lincom(self, save=False, target_accuracy=0, best_accuracy=0, adversarial=False, eps=0):
        """Refit the linear-combination coefficients once and report accuracy.

        Steps, all on the test split:
          1. ``evalLincom(lincom=True)`` records the current per-batch losses.
             It compares against ``self.baseline_acts``, so a baseline pass
             must have been run beforehand.
          2. ``update_lambdas()`` recomputes the coefficients.
          3. ``eval(lincom=True)`` measures the resulting top-1 accuracy.

        The losses from step 1 are stored in ``self.loss`` and the model is
        left in train mode.

        Args:
            save: Unused; accepted for backward compatibility.
            target_accuracy: Unused; accepted for backward compatibility.
            best_accuracy: Unused; accepted for backward compatibility.
            adversarial: Use the adversarial test loader when ``True``.
            eps: Unused; accepted for backward compatibility.

        Returns:
            Top-1 accuracy (in percent) after the coefficient update.
        """
        self.model.lincom = True
        self.model = self.model.cuda()
        self.model.eval()

        ### Calculate loss based on activations and update lambdas
        loss = self.evalLincom(store=False, adversarial=adversarial, data="Testing", lincom=True)
        self.update_lambdas()
        self.loss = loss

        ### Evaluate accuracy with the updated lambdas
        errors = self.eval(adversarial=adversarial, data="Testing", lincom=True)
        accuracy = 100 - errors[0]  # Top-1 accuracy.

        self.model.train()

        print('Finished finetuning...')
        # The first number is the error and the second the accuracy, so label
        # the line accordingly.
        print('Error/accuracy: %0.2f%%, %0.2f%%' %(100 - accuracy, accuracy))
        print('-' * 16)
        return accuracy
    
    
        
    
    
    
    
    
    
    
    


    def load_model(self,state_dict, f_b_start=0):
        """Load a state dict fully, or only into selected module positions.

        NOTE(review): this class defines ``load_model`` a second time further
        down. Python keeps the later definition, so this version is shadowed
        and is not the one that runs.

        Args:
            state_dict: Mapping from parameter names to tensors.
            f_b_start: If greater than 0, only modules whose position in
                ``self.model.named_modules()`` is in the fixed index list and
                at least ``f_b_start`` receive their ``weight`` and ``bias``.
                Otherwise the whole state dict is loaded.
        """
        print(f_b_start)
        if not f_b_start > 0:
            self.model.load_state_dict(state_dict)
            return

        # Positions (in named_modules() order) that are eligible for copying.
        selected = (5, 9, 12, 16, 19, 22, 26, 29, 32, 36, 39, 42, 48, 51, 54)
        with torch.no_grad():
            for position, (qualified_name, layer) in enumerate(self.model.named_modules()):
                if position not in selected or position < f_b_start:
                    continue
                layer.weight.copy_(state_dict[qualified_name + ".weight"])
                layer.bias.copy_(state_dict[qualified_name + ".bias"])
            
    def save_model(self, path):
        """Write the model's ``state_dict`` to ``path`` using ``torch.save``.

        Args:
            path: Destination file for the serialized weights.
        """
        print("Saving model to path: ", path)
        weights = self.model.state_dict()
        torch.save(weights, path)
        
    def load_model(self,state_dict, f_b_start=0):
        """Load weights into ``self.model`` fully or from a module index onward.

        With ``f_b_start <= 0`` the whole ``state_dict`` is loaded through
        ``load_state_dict``. With ``f_b_start > 0`` only ``BatchNorm2d``,
        ``Conv2d`` and ``Linear`` modules whose position in
        ``self.model.named_modules()`` is at least ``f_b_start`` are updated
        in place; earlier modules keep their current values.

        For the selected modules ``weight`` and ``bias`` are copied, plus
        ``running_mean``, ``running_var`` and ``num_batches_tracked`` for
        batch norm. A tensor is copied only if the module actually has it, so
        bias-free layers are handled for any architecture rather than by
        checking the network name.

        Args:
            state_dict: Mapping from ``"<module name>.<tensor name>"`` to
                tensors.
            f_b_start: First module position to load; 0 loads everything.

        Raises:
            KeyError: If ``state_dict`` lacks an entry for a tensor that a
                selected module has.
        """
        print("f_b_start: ", f_b_start)
        if not f_b_start > 0:
            self.model.load_state_dict(state_dict)
            return

        param_names = ("weight", "bias")
        batchnorm_names = param_names + ("running_mean", "running_var", "num_batches_tracked")

        with torch.no_grad():
            for idx, (prefix, module) in enumerate(self.model.named_modules()):
                if idx < f_b_start:
                    continue

                if isinstance(module, nn.BatchNorm2d):
                    tensor_names = batchnorm_names
                elif isinstance(module, (nn.Conv2d, nn.Linear)):
                    tensor_names = param_names
                else:
                    # Containers, activations, pooling, ... hold nothing to copy.
                    continue

                for tensor_name in tensor_names:
                    target = getattr(module, tensor_name)
                    # None means the layer was built without this tensor
                    # (e.g. bias=False or affine=False).
                    if target is not None:
                        target.copy_(state_dict[prefix + "." + tensor_name])
            
    def freeze_fa(self, f_b_start):
        """Freeze conv, linear and batch-norm modules located before ``f_b_start``.

        Modules are numbered by their position in
        ``self.model.named_modules()``. For positions below ``f_b_start`` the
        parameters of ``Conv2d``, ``Linear`` and ``BatchNorm2d`` modules get
        ``requires_grad = False``. Every module at or after ``f_b_start`` is
        printed, whatever its type.

        Args:
            f_b_start: First module position that is left trainable.
        """
        print("unfrozen layers:")
        freezable = (nn.Conv2d, nn.Linear, nn.BatchNorm2d)
        for position, (qualified_name, layer) in enumerate(self.model.named_modules()):
            if not position < f_b_start:
                print(position," ", qualified_name)
                continue
            if isinstance(layer, freezable):
                for weight in layer.parameters():
                    weight.requires_grad = False
